Data-Driven Prediction of Pennsylvania COVID-19 Cases through Google Search Trends

BMIN5030 Final Project

Author

Quan Minh Nguyen


1 Overview

This project leverages the Google COVID-19 Open Data Repository with a strong emphasis on Google search trend signals to visualize and predict COVID-19 case trajectories in Pennsylvania. By integrating search behavior with epidemiological, health, and vaccination data, the goal is to evaluate how real-time online queries can serve as an early-warning system for outbreaks and identify which search behaviors most strongly correlate with future surges in confirmed cases. In discussions with Dr. Kai Wang and lab manager Umair Ahsan, I learned that although the Google COVID-19 dataset offers powerful real-time surveillance potential, it suffers from substantial missingness and inconsistent reporting, requiring thoughtful feature selection and imputation techniques. Dr. Wang further emphasized the need to balance model accuracy with interpretability, especially when results are used to inform public health policy. Their guidance highlighted the broader scientific value of this work—bridging behavioral data science and epidemiology to explore how online search activity reflects real-world disease spread and can enhance pandemic preparedness.

https://github.com/quannguyenminh103/BMIN5030_Final_Project

2 Introduction

The COVID-19 pandemic dramatically reshaped daily life and public health systems, highlighting the need for better tools to anticipate and respond to infectious disease surges. Even as the pandemic has subsided, understanding the early signals that predict rising case numbers remains critical for future preparedness. This project focuses on using Google search trend data and other public features (weather, government policy, etc.) from the Google COVID-19 Open Data Repository to visualize and forecast case and death patterns across U.S. regions. The core goal is to evaluate whether real-time search behavior—such as individuals looking up symptoms like loss of taste or breathing difficulties—can serve as an early indicator of community transmission before official reports catch up. By pairing search trends with supporting datasets on vaccination, mobility, and healthcare impact, I aim to develop a simple machine learning model that reveals which online behaviors most strongly track and predict shifts in the pandemic.

This project is inherently interdisciplinary, positioned at the intersection of epidemiology, data science, and behavioral informatics. Google search behavior reflects how people respond when they or those around them begin experiencing symptoms, offering a unique perspective into disease spread that complements clinical data. Through discussions with Umair Ahsan, I learned that this dataset contains substantial missing or noisy values, which requires careful cleaning and feature selection to ensure valid results. Dr. Kai Wang emphasized the importance of interpretability in public health models, so that predictive insights can be clearly understood and trusted by decision-makers. Their combined guidance shaped an approach centered on both data quality and explainability, using search trends not only to improve prediction accuracy but also to uncover how information-seeking behavior reflects real-world COVID-19 transmission dynamics.

3 Methods

3.0.1 Loading R Packages

library(tibble)
library(data.table)
library(tidyverse)
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr     1.1.4     ✔ purrr     1.2.0
✔ forcats   1.0.1     ✔ readr     2.1.5
✔ ggplot2   4.0.1     ✔ stringr   1.5.1
✔ lubridate 1.9.4     ✔ tidyr     1.3.1
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::between()     masks data.table::between()
✖ dplyr::filter()      masks stats::filter()
✖ dplyr::first()       masks data.table::first()
✖ lubridate::hour()    masks data.table::hour()
✖ lubridate::isoweek() masks data.table::isoweek()
✖ dplyr::lag()         masks stats::lag()
✖ dplyr::last()        masks data.table::last()
✖ lubridate::mday()    masks data.table::mday()
✖ lubridate::minute()  masks data.table::minute()
✖ lubridate::month()   masks data.table::month()
✖ lubridate::quarter() masks data.table::quarter()
✖ lubridate::second()  masks data.table::second()
✖ purrr::transpose()   masks data.table::transpose()
✖ lubridate::wday()    masks data.table::wday()
✖ lubridate::week()    masks data.table::week()
✖ lubridate::yday()    masks data.table::yday()
✖ lubridate::year()    masks data.table::year()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(lubridate)
library(dplyr)
library(zoo)

Attaching package: 'zoo'

The following objects are masked from 'package:data.table':

    yearmon, yearqtr

The following objects are masked from 'package:base':

    as.Date, as.Date.numeric
library(tidyr)
library(stringr)
library(scales)

Attaching package: 'scales'

The following object is masked from 'package:purrr':

    discard

The following object is masked from 'package:readr':

    col_factor
# Missingness & EDA
library(naniar)      # for missingness visualization
library(janitor)     # for tabulations, clean names

Attaching package: 'janitor'

The following objects are masked from 'package:stats':

    chisq.test, fisher.test
# Time-series and rolling windows
library(slider)      # for rolling 14-day features

# Modeling
library(glmnet)
Loading required package: Matrix

Attaching package: 'Matrix'

The following objects are masked from 'package:tidyr':

    expand, pack, unpack

Loaded glmnet 4.1-10
library(caret)
Loading required package: lattice

Attaching package: 'caret'

The following object is masked from 'package:purrr':

    lift
library(xgboost)

Attaching package: 'xgboost'

The following object is masked from 'package:dplyr':

    slice
library(randomForest)
randomForest 4.7-1.2
Type rfNews() to see new features/changes/bug fixes.

Attaching package: 'randomForest'

The following object is masked from 'package:dplyr':

    combine

The following object is masked from 'package:ggplot2':

    margin
library(forecast)
Registered S3 method overwritten by 'quantmod':
  method            from
  as.zoo.data.frame zoo 
# Spatial & maps
library(maps)

Attaching package: 'maps'

The following object is masked from 'package:purrr':

    map
library(gganimate)
library(transformr)  # smoother gganimate transitions
library(ggplot2)
library(usmap)
library(plotly)

Attaching package: 'plotly'

The following object is masked from 'package:xgboost':

    slice

The following object is masked from 'package:ggplot2':

    last_plot

The following object is masked from 'package:stats':

    filter

The following object is masked from 'package:graphics':

    layout
# For reproducibility
set.seed(123)
options(bitmapType = "cairo")
library(magick)
Linking to ImageMagick 7.1.0.37
Enabled features: fontconfig, freetype, ghostscript, lcms, pango, x11
Disabled features: cairo, fftw, heic, raw, rsvg, webp

3.0.2 Data Loading & Exploration

We used daily COVID-19 data for all U.S. states and territories from the Google COVID-19 Open Data Repository (GoogleCloudPlatform, 2022). Each state-level dataset was retrieved and merged into a unified data frame in R using data.table::fread() and rbindlist().

# state_abbrev <- c(
#   AL="Alabama", AK="Alaska", AZ="Arizona", AR="Arkansas", CA="California",
#   CO="Colorado", CT="Connecticut", DE="Delaware", DC="District of Columbia",
#   FL="Florida", GA="Georgia", HI="Hawaii", ID="Idaho", IL="Illinois",
#   IN="Indiana", IA="Iowa", KS="Kansas", KY="Kentucky", LA="Louisiana",
#   ME="Maine", MD="Maryland", MA="Massachusetts", MI="Michigan", MN="Minnesota",
#   MS="Mississippi", MO="Missouri", MT="Montana", NE="Nebraska", NV="Nevada",
#   NH="New Hampshire", NJ="New Jersey", NM="New Mexico", NY="New York",
#   NC="North Carolina", ND="North Dakota", OH="Ohio", OK="Oklahoma",
#   OR="Oregon", PA="Pennsylvania", RI="Rhode Island", SC="South Carolina",
#   SD="South Dakota", TN="Tennessee", TX="Texas", UT="Utah", VT="Vermont",
#   VA="Virginia", WA="Washington", WV="West Virginia", WI="Wisconsin", WY="Wyoming",
#   AS="American Samoa", GU="Guam", MP="Northern Mariana Islands",
#   PR="Puerto Rico", VI="Virgin Islands"
# )
# 
# dt_list <- list()
# 
# for (abbrev in names(state_abbrev)) {
#   url <- sprintf("https://storage.googleapis.com/covid19-open-data/v3/location/US_%s.csv", abbrev)
#   message("Downloading: ", abbrev, " -> ", url)
#   tryCatch({
#     dt <- fread(url)
#     dt[, `:=`(state_abbrev = abbrev, state_name = state_abbrev[abbrev])]
#     dt_list[[abbrev]] <- dt
#   }, error = function(e) {
#     warning("Failed to download for ", abbrev, ": ", conditionMessage(e))
#   })
# }
# 
# covid19_data <- rbindlist(dt_list, use.names = TRUE, fill = TRUE)
# 
# if ("date" %in% names(covid19_data)) {
#   covid19_data[, date := as.IDate(date)]
# }
# setcolorder(covid19_data, c("state_abbrev", "state_name", setdiff(names(covid19_data), c("state_abbrev", "state_name"))))
# 
# fwrite(covid19_data, "covid19_open_data_US_all.csv")

Each record represents one date for a given location, containing variables such as new cases, new deaths, hospitalizations, vaccination rates, testing volumes, and mobility indices:

covid19_data <- read_csv('covid19_open_data_US_all.csv')
Rows: 55496 Columns: 615
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
chr   (22): state_abbrev, state_name, location_key, place_id, wikidata_id, d...
dbl  (592): aggregation_level, new_confirmed, new_deceased, new_recovered, n...
date   (1): date

ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
covid19_data %>% filter(state_name == 'Pennsylvania')
# A tibble: 991 × 615
   state_abbrev state_name   location_key date       place_id        wikidata_id
   <chr>        <chr>        <chr>        <date>     <chr>           <chr>      
 1 PA           Pennsylvania US_PA        2020-01-01 ChIJieUyHiaALY… Q1400      
 2 PA           Pennsylvania US_PA        2020-01-02 ChIJieUyHiaALY… Q1400      
 3 PA           Pennsylvania US_PA        2020-01-03 ChIJieUyHiaALY… Q1400      
 4 PA           Pennsylvania US_PA        2020-01-04 ChIJieUyHiaALY… Q1400      
 5 PA           Pennsylvania US_PA        2020-01-05 ChIJieUyHiaALY… Q1400      
 6 PA           Pennsylvania US_PA        2020-01-06 ChIJieUyHiaALY… Q1400      
 7 PA           Pennsylvania US_PA        2020-01-07 ChIJieUyHiaALY… Q1400      
 8 PA           Pennsylvania US_PA        2020-01-08 ChIJieUyHiaALY… Q1400      
 9 PA           Pennsylvania US_PA        2020-01-09 ChIJieUyHiaALY… Q1400      
10 PA           Pennsylvania US_PA        2020-01-10 ChIJieUyHiaALY… Q1400      
# ℹ 981 more rows
# ℹ 609 more variables: datacommons_id <chr>, country_code <chr>,
#   country_name <chr>, subregion1_code <chr>, subregion1_name <chr>,
#   iso_3166_1_alpha_2 <chr>, iso_3166_1_alpha_3 <chr>,
#   aggregation_level <dbl>, new_confirmed <dbl>, new_deceased <dbl>,
#   new_recovered <dbl>, new_tested <dbl>, cumulative_confirmed <dbl>,
#   cumulative_deceased <dbl>, cumulative_recovered <dbl>, …

3.0.3 The Most Challenging Problem - Missing Data!

The dataset used in this study contains 55,496 observations and 615 features, covering COVID-19 data for 56 U.S. states and territories from January 1, 2020, to September 17, 2022; this project focuses on Pennsylvania. Each observation corresponds to a specific date and location, and the dataset includes diverse variables such as epidemiological indicators (new and cumulative cases, deaths, testing rates), vaccination coverage, demographic information, public policy indices, mobility, Google search trends, and healthcare capacity metrics.

Initial data inspection revealed substantial missingness across features: only 505 of the 615 variables contained fewer than 60% null values, indicating a considerable missing-data problem. From manual checking, almost all features with more than 60% missing data are associated with age and sex splits (e.g., new_tested_age_0). This is reasonable for a high-level data repository, where such fine-grained individual breakdowns are difficult to track consistently. To improve data quality and model reliability, I manually reviewed each feature and removed variables with less than 40% completeness, reasoning that these variables likely contribute little predictive value.

# Compute proportion of missing for each column
miss_summary <- covid19_data %>%
  summarise(across(everything(), ~ mean(is.na(.)))) %>%
  pivot_longer(cols = everything(),
               names_to = "variable",
               values_to = "prop_missing") %>%
  arrange(desc(prop_missing))

head(miss_summary, 111)
# A tibble: 111 × 2
   variable                prop_missing
   <chr>                          <dbl>
 1 new_tested_age_0               0.988
 2 new_tested_age_1               0.988
 3 new_tested_age_2               0.988
 4 new_tested_age_3               0.988
 5 new_tested_age_4               0.988
 6 new_tested_age_5               0.988
 7 new_tested_male                0.988
 8 new_tested_female              0.988
 9 cumulative_tested_age_0        0.988
10 cumulative_tested_age_1        0.988
# ℹ 101 more rows
# Visualize missingness for top variables (just as illustration)
miss_summary %>%
  slice_max(prop_missing, n = 30) %>%
  ggplot(aes(x = reorder(variable, prop_missing), y = prop_missing)) +
  geom_col() +
  coord_flip() +
  labs(
    title = "Proportion of Missing Values by Variable",
    x = "Variable",
    y = "Proportion Missing"
  )

essential_cols <- c("date", "state_name", "state_abbrev")

vars_to_keep <- miss_summary %>%
  filter(prop_missing <= 0.6 | variable %in% essential_cols) %>%
  pull(variable) %>%
  unique()

covid19_reduced <- covid19_data %>%
  select(all_of(vars_to_keep))

3.0.4 Data Imputation

For the remaining features, missing values were handled according to the type and temporal behavior of each variable. For daily vaccinations, we assume that a missing report means no doses were administered. For categorical or policy-related variables (e.g., school closures, stay-at-home mandates), missing segments were filled forward within each state, on the assumption that a policy remains in effect until officially modified. Google search trends (the most interesting signal in this dataset!) and daily new cases/deaths within each state were imputed using linear interpolation, as they tend to vary smoothly over short time windows; any case/death values still missing after interpolation were set to zero. Note that linear interpolation (na.approx) only fills gaps between existing non-NA values; it cannot produce values at the start or end of a series, where entire segments may be missing, as in the Google search trends. Therefore, missing values at the beginning of each search-trend series are imputed with the first observed value of that feature, while missing values at the end are filled by carrying forward the most recent observation. This avoids implying an absence of public interest, since individuals may have searched for these conditions for reasons unrelated to COVID-19 even before the pandemic gained traction. Finally, three manufacturers provide COVID-19 vaccines in the dataset; because we assume their effects are equivalent, we combine them into a single measure, which also reduces missingness by aggregating vaccination data from multiple sources.
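The edge problem described above can be seen on a toy vector (illustrative only, not part of the pipeline): na.approx() interpolates interior gaps but leaves leading and trailing NAs, which a forward-then-backward LOCF pass then covers.

```r
library(zoo)

x <- c(NA, NA, 3, NA, 5, NA)

# Interior gap (position 4) is interpolated; edges stay NA
na.approx(x, na.rm = FALSE)
# -> NA NA 3 4 5 NA

# Forward LOCF, then backward LOCF, covers both edges
na.locf(na.locf(x, na.rm = FALSE), na.rm = FALSE, fromLast = TRUE)
# -> 3 3 3 3 5 5
```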

# --------------------------------------------------
# count-like series (ICU, hospital, etc.)
# Rule:
#  1) if all NA → leave as is (we'll later drop all-NA columns)
#  2) early NA  → 0
#  3) middle NA → forward fill
#  4) tail NA   → backward fill
# --------------------------------------------------
fill_count_series <- function(x) {
  if (all(is.na(x))) return(x)
  
  y <- x
  idx <- which(!is.na(y))
  first <- idx[1]
  last  <- tail(idx, 1)
  
  # early missing → 0
  if (first > 1) {
    y[1:(first - 1)] <- 0
  }
  # forward LOCF for interior NAs
  y <- na.locf(y, na.rm = FALSE)
  # backward LOCF for trailing NAs
  y <- na.locf(y, na.rm = FALSE, fromLast = TRUE)
  y
}
# --------------------------------------------------
# smooth trend-like series (search, mobility, weather)
# Rule:
#  1) if all NA → keep NA
#  2) interpolate internal gaps
#  3) early NA  → first non-NA value
#  4) tail NA   → last non-NA value
# --------------------------------------------------
fill_trend_series <- function(x) {
  if (all(is.na(x))) return(x)
  
  y <- x
  # interpolate internal gaps
  y <- na.approx(y, na.rm = FALSE)
  idx <- which(!is.na(y))
  first <- idx[1]
  last  <- tail(idx, 1)
  
  # early → first value
  if (first > 1) {
    y[1:first] <- y[first]
  }
  # tail → last value
  if (last < length(y)) {
    y[last:length(y)] <- y[last]
  }
  y
}

# ======================================================
# Daily COVID & Vaccination Imputation
# ======================================================
covid19_reduced <- covid19_reduced %>%
  mutate(
    new_confirmed = if_else(new_confirmed < 0, NA_real_, new_confirmed),
    new_deceased  = if_else(new_deceased < 0, NA_real_, new_deceased)
  )
covid19_clean <- covid19_reduced %>%
  arrange(state_name, date) %>%
  group_by(state_name) %>%
  mutate(
    # Daily cases/deaths: interpolate interior gaps; remaining edge NAs -> 0
    new_confirmed = na.approx(new_confirmed, na.rm = FALSE),
    new_deceased  = na.approx(new_deceased, na.rm = FALSE),
    new_confirmed = replace_na(new_confirmed, 0),
    new_deceased = replace_na(new_deceased, 0),
    # Daily vaccine doses per brand (fill with 0)
    across(starts_with("new_vaccine_doses_administered"),
           ~ replace_na(.x, 0)),
    
    # Daily persons fully vaccinated per brand + aggregate
    across(starts_with("new_persons_fully_vaccinated"),
           ~ replace_na(.x, 0)),
    
    # Some datasets have total new persons vaccinated
    new_persons_vaccinated = replace_na(new_persons_vaccinated, 0),
    
    # Total daily vaccine doses (for convenience)
    new_vaccinations_total =
      new_vaccine_doses_administered_pfizer +
      new_vaccine_doses_administered_moderna +
      new_vaccine_doses_administered_janssen
  ) %>%
  # ======================================================
  # Recompute ALL cumulative case/vaccine cols
  #    → ensures no NA if daily values exist
  # ======================================================
  mutate(
    # cases/deaths
    cumulative_confirmed = cumsum(new_confirmed),
    cumulative_deceased  = cumsum(new_deceased),
    
    # brand-specific persons fully vaccinated
    cumulative_persons_fully_vaccinated_pfizer  =
      cumsum(new_persons_fully_vaccinated_pfizer),
    cumulative_persons_fully_vaccinated_moderna =
      cumsum(new_persons_fully_vaccinated_moderna),
    cumulative_persons_fully_vaccinated_janssen =
      cumsum(new_persons_fully_vaccinated_janssen),
    
    # brand-specific doses
    cumulative_vaccine_doses_administered_pfizer  =
      cumsum(new_vaccine_doses_administered_pfizer),
    cumulative_vaccine_doses_administered_moderna =
      cumsum(new_vaccine_doses_administered_moderna),
    cumulative_vaccine_doses_administered_janssen =
      cumsum(new_vaccine_doses_administered_janssen),
    
    # aggregate vaccination cumulative
    cumulative_persons_fully_vaccinated =
      cumsum(new_persons_fully_vaccinated),
    cumulative_persons_vaccinated =
      cumsum(new_persons_vaccinated),
    cumulative_vaccine_doses_administered =
      cumsum(new_vaccine_doses_administered),
    cumulative_vaccinations_total =
      cumsum(new_vaccinations_total)
  ) %>%
  ungroup()

# ======================================================
# Weather interpolation (per state) → interpolate + edge fill
# ======================================================
covid19_clean <- covid19_clean %>%
  group_by(state_name) %>%
  mutate(
    across(
      matches("temperature|rainfall|humidity|dew_point|snowfall"),
      fill_trend_series
    )
  ) %>%
  ungroup()

# ======================================================
# Search Trends — interpolate + edge fill
# ======================================================
search_cols <- grep("^search_trends_", names(covid19_clean), value = TRUE)

covid19_clean <- covid19_clean %>%
  group_by(state_name) %>%
  arrange(date) %>%
  mutate(
    across(all_of(search_cols), fill_trend_series)
  ) %>%
  ungroup()

# ======================================================
# Policy variables — forward & backward filling
#     (no interpolation, just carry last policy level)
# ======================================================
policy_patterns <- c("closing","stringency","support","policy",
                     "public_transport","stay_at_home",
                     "facial_coverings","testing_policy")

policy_cols <- grep(paste(policy_patterns, collapse = "|"),
                    names(covid19_clean), value = TRUE)

covid19_clean <- covid19_clean %>%
  group_by(state_name) %>%
  mutate(
    across(all_of(policy_cols),
           ~ if (all(is.na(.x))) {
               .x
             } else {
               z <- na.locf(.x, na.rm = FALSE)
               z <- na.locf(z, na.rm = FALSE, fromLast = TRUE)
               z
             })
  ) %>%
  ungroup()

# ======================================================
# Hospital & Mobility — count-like / trend-like
#     Apply your rules:
#       1) early NA → 0  (for counts), baseline (for mobility)
#       2) internal NA → interpolation
#       3) tail NA → last value
# ======================================================
covid19_clean <- covid19_clean %>%
  group_by(state_name) %>%
  mutate(
    # ICU / hospital counts (use count-style fill)
    across(
      c(current_intensive_care_patients,
        current_hospitalized_patients,
        new_hospitalized_patients,
        cumulative_hospitalized_patients),
      fill_count_series
    ),
    
    # Mobility: trend-like (interpolate + edge fill)
    across(
      c(mobility_parks,
        mobility_transit_stations,
        mobility_retail_and_recreation,
        mobility_grocery_and_pharmacy,
        mobility_workplaces,
        mobility_residential),
      fill_trend_series
    )
  ) %>%
  ungroup()

# ======================================================
# Remove metadata columns, tidy ordering
# ======================================================
covid19_clean <- covid19_clean %>%
  select(-any_of(c(
    "location_key","place_id","wikidata_id","datacommons_id",
    "subregion1_code","subregion1_name","iso_3166_1_alpha_2",
    "iso_3166_1_alpha_3","aggregation_level","openstreetmap_id"
  ))) %>%
  relocate(state_name, state_abbrev, date)

Now, let’s check whether any missing values remain in the retained columns:

# Compute proportion of missing for each column
missing_by_state <- covid19_clean %>%
  group_by(state_name) %>%
  summarise(across(
    .cols = everything(),
    .fns = ~ mean(is.na(.)) * 100,
    .names = "pct_miss_{col}"
  )) %>%
  ungroup()

# Preview the first few rows
head(missing_by_state, 10)
# A tibble: 10 × 499
   state_name         pct_miss_state_abbrev pct_miss_date pct_miss_search_tren…¹
   <chr>                              <dbl>         <dbl>                  <dbl>
 1 Alabama                                0             0                      0
 2 Alaska                                 0             0                    100
 3 American Samoa                         0             0                    100
 4 Arizona                                0             0                      0
 5 Arkansas                               0             0                      0
 6 California                             0             0                      0
 7 Colorado                               0             0                      0
 8 Connecticut                            0             0                      0
 9 Delaware                               0             0                    100
10 District of Colum…                     0             0                    100
# ℹ abbreviated name: ¹​pct_miss_search_trends_photodermatitis
# ℹ 495 more variables: pct_miss_search_trends_shallow_breathing <dbl>,
#   pct_miss_search_trends_viral_pneumonia <dbl>,
#   pct_miss_search_trends_allergic_conjunctivitis <dbl>,
#   pct_miss_search_trends_burning_chest_pain <dbl>,
#   pct_miss_search_trends_polydipsia <dbl>,
#   pct_miss_search_trends_angular_cheilitis <dbl>, …

The table above shows that all features have been fully imputed. Note that some features remain 100% missing for certain states, where no information was ever reported.
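A housekeeping note: the comment in fill_count_series above says all-NA columns will be dropped later, but no such step appears in the pipeline. A minimal sketch of that step, shown on a hypothetical toy frame (df and the select(where(...)) idiom are illustrative, not from the original code):

```r
library(dplyr)

# Toy frame: column b is entirely NA and should be dropped
df <- tibble(a = c(1, 2), b = c(NA_real_, NA_real_), c = c(NA, 3))

df_dropped <- df %>%
  select(where(~ !all(is.na(.x))))

names(df_dropped)
# -> "a" "c"
```

Applied per state, the same predicate could flag the state-specific all-NA search-trend columns seen in the table above.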

In addition, following several prior studies, we include lagged case/death features (7-day lags here; 14-day lags follow the same logic) because COVID-19 transmission and symptom onset typically occur within one to two weeks after exposure. These lags capture the delayed relationship between earlier infections and later outcomes such as new deaths or hospitalizations; recent transmission within these windows is strongly indicative of current outbreak momentum. Such features smooth day-to-day noise and provide a more reliable signal of ongoing outbreak severity than daily counts alone, ultimately supporting improved forecasting performance. The visualizations below show that the 7-day lagged cases/deaths closely track new confirmed cases/deaths, indicating strong predictive power. Starting in April 2021, COVID-19 vaccinations became available to all adults, leading to a steady increase in vaccination uptake over time.

# ======================================================
# Feature Engineering for ML
#     - lag7 as features
#     - target = new_confirmed / new_deceased
# ======================================================
covid19_ml <- covid19_clean %>%
  group_by(state_name) %>%
  arrange(date) %>%
  mutate(
    # Lag features (early days → 0)
    lag_cases_7   = replace_na(lag(new_confirmed, 7), 0),
    lag_deaths_7  = replace_na(lag(new_deceased, 7), 0),
    # Explicit targets for clarity
    target_cases  = new_confirmed,
    target_deaths = new_deceased
  ) %>%
  ungroup()

national_ts <- covid19_ml %>%
  group_by(date) %>%
  summarise(
    total_new_cases   = sum(new_confirmed, na.rm = TRUE),
    total_new_deaths  = sum(new_deceased, na.rm = TRUE),
    new_vax_tot = sum(new_vaccinations_total, na.rm = TRUE),
    lag_cases_7  = sum(lag_cases_7, na.rm = TRUE),
    lag_deaths_7 = sum(lag_deaths_7, na.rm = TRUE)
  ) %>% ungroup

# New cases with 7-day lag overlay
ggplot(national_ts, aes(x = date)) +
  geom_col(aes(y = total_new_cases), fill = "grey70") +
  geom_line(aes(y = lag_cases_7), linewidth = 0.7) +
  labs(title = "US Daily New Cases with 7-Day Lag Sum",
       x = "Date", y = "Cases") +
  theme_minimal()

# New deaths with 7-day lag overlay
ggplot(national_ts, aes(x = date)) +
  geom_col(aes(y = total_new_deaths), fill = "grey70") +
  geom_line(aes(y = lag_deaths_7), linewidth = 0.7) +
  labs(title = "US Daily New Deaths with 7-Day Lag Sum",
       x = "Date", y = "Deaths") +
  theme_minimal()

# Total vaccinations over time
ggplot(national_ts, aes(x = date, y = cumsum(new_vax_tot))) +
  geom_line() +
  labs(title = "US Cumulative Vaccinations (All Brands)",
       x = "Date", y = "Cumulative doses") +
  theme_minimal()
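The slider package was loaded earlier for rolling 14-day features but is not used above. The sketch below shows, on hypothetical toy data, how 14-day lag and 14-day rolling-mean features could be constructed alongside the 7-day lags (the column names lag_cases_14 and roll_mean_cases_14 and the toy frame are illustrative, not from the original pipeline):

```r
library(dplyr)
library(slider)

# Hypothetical per-state daily series (toy data for illustration)
toy <- tibble(
  state_name    = "Pennsylvania",
  date          = as.Date("2020-03-01") + 0:20,
  new_confirmed = c(rep(0, 7), seq(2, 28, by = 2))
)

toy_feats <- toy %>%
  group_by(state_name) %>%
  arrange(date, .by_group = TRUE) %>%
  mutate(
    # 14-day lag, with early days set to 0 as in the 7-day features
    lag_cases_14 = coalesce(lag(new_confirmed, 14), 0),
    # trailing 14-day rolling mean (current day plus 13 prior days)
    roll_mean_cases_14 = slide_dbl(new_confirmed, mean, .before = 13)
  ) %>%
  ungroup()
```

The same mutate() could be added to the covid19_ml step above to realize the 14-day window discussed in the text.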

3.0.5 COVID-19 Cumulative Cases/Deaths in the US over time (GIF)

To better understand the geographic distribution and severity of the COVID-19 pandemic across the United States, we visualize cumulative and peak COVID-19 outcomes using state-level choropleth maps. The code first attaches federal FIPS identifiers (via the usmap package), identifies the most recent reporting date, and determines the national peak dates for new cases and new deaths. We then generate four maps: cumulative confirmed cases and cumulative deaths as of the latest date, and daily new cases and deaths on their respective national peak days. These maps allow us to assess both the overall burden of disease and the most intense periods of transmission, highlighting spatial disparities in how the pandemic unfolded. By examining current cumulative totals alongside historical peaks, we can better interpret trends, compare state-to-state impacts, and understand which regions experienced the most severe health outcomes. California, Texas, Florida, and New York were the most heavily impacted states, and the eastern US generally saw more cases and deaths than the western US. Pennsylvania reported about 29,854 new cases on the national peak case date (2022-01-10) and 99 deaths on the peak death date (2021-01-12). As of 2022-09-17, Pennsylvania had recorded approximately 3,227,209 cumulative cases and 47,300 deaths.

add_fips <- function(df) {
  df %>%
    mutate(
      fips = usmap::fips(state = state_name),
      fips = case_when(
        state_name == "American Samoa"           ~ "60",
        state_name == "Guam"                     ~ "66",
        state_name == "Northern Mariana Islands" ~ "69",
        state_name == "Virgin Islands"           ~ "78",
        TRUE ~ as.character(fips)
      )
    )
}
covid_map_data <- covid19_ml %>%
  add_fips()

last_date <- max(covid_map_data$date, na.rm = TRUE)

peak_case_day <- national_ts %>%
  slice_max(total_new_cases, n = 1, with_ties = FALSE) %>%
  pull(date)

peak_death_day <- national_ts %>%
  slice_max(total_new_deaths, n = 1, with_ties = FALSE) %>%
  pull(date)
custom_colors <- c(
  "#FFFBC8", # very light yellow
  "#FECC5C", # light orange
  "#FD8D3C", # orange
  "#E31A1C", # red
  "#B10026"  # dark red (max)
)

## Cumulative Cases
plot_data_cases <- covid_map_data %>% filter(date == last_date) %>%
  mutate(
    hover_text = paste0(
      state_name, "<br>",
      "Cases: ", scales::comma(cumulative_confirmed), "<br>",
      "Date: ", last_date
    )
  )

fig_cases <- plot_ly(
  plot_data_cases,
  type = "choropleth",
  locations = ~state_abbrev,
  locationmode = "USA-states",
  z = ~cumulative_confirmed,
  text = ~hover_text,
  hoverinfo = "text",
  colorscale = custom_colors,
  reversescale = FALSE,
  colorbar = list(title = "Cases")
) 
fig_cases <- fig_cases %>%
  layout(
    title = list(text = paste("Total Cumulative COVID-19 Cases (as of", last_date, ")")),
    geo = list(
      scope = "usa",
      projection = list(type = "albers usa"),
      showlakes = TRUE,
      lakecolor = "white"
    )
  )
fig_cases
## Cumulative Deaths
plot_data_deaths <- covid_map_data %>%  filter(date == last_date) %>% 
  mutate(
    hover_text = paste0(
      state_name, "<br>",
      "Deaths: ", scales::comma(cumulative_deceased), "<br>",
      "Date: ", last_date
    )
  )

fig_deaths <- plot_ly(
  plot_data_deaths,
  type = "choropleth",
  locations = ~state_abbrev,
  locationmode = "USA-states",
  z = ~cumulative_deceased,
  text = ~hover_text,
  hoverinfo = "text",
  colorscale = custom_colors,
  reversescale = FALSE,
  colorbar = list(title = "Deaths")
) 
fig_deaths <- fig_deaths %>%
  layout(
    title = list(text = paste("Total Cumulative COVID-19 Deaths (as of", last_date, ")")),
    geo = list(
      scope = "usa",
      projection = list(type = "albers usa"),
      showlakes = TRUE,
      lakecolor = "white"
    )
  )
fig_deaths
# Peak Cases
peak_cases_map <- covid_map_data %>%
  filter(date == peak_case_day)
plot_peak_cases <- peak_cases_map %>%
  mutate(
    hover_text = paste0(
      state_name, "<br>",
      "New Cases: ", scales::comma(new_confirmed), "<br>",
      "Date: ", peak_case_day
    )
  )

fig_peak_cases <- plot_ly(
  plot_peak_cases,
  type = "choropleth",
  locations = ~state_abbrev,
  locationmode = "USA-states",
  z = ~new_confirmed,
  text = ~hover_text,
  hoverinfo = "text",
  colorscale = custom_colors,
  reversescale = FALSE,
  colorbar = list(title = "New Cases")
) %>%
  layout(
    title = paste("Peak Daily New COVID-19 Cases (", peak_case_day, ")", sep=""),
    geo = list(scope = "usa")
  )

fig_peak_cases
# Peak Deaths
peak_deaths_map <- covid_map_data %>%
  filter(date == peak_death_day)
plot_peak_deaths <- peak_deaths_map %>%
  mutate(
    hover_text = paste0(
      state_name, "<br>",
      "New Deaths: ", scales::comma(new_deceased), "<br>",
      "Date: ", peak_death_day
    )
  )

fig_peak_deaths <- plot_ly(
  plot_peak_deaths,
  type = "choropleth",
  locations = ~state_abbrev,
  locationmode = "USA-states",
  z = ~new_deceased,
  text = ~hover_text,
  hoverinfo = "text",
  colorscale = custom_colors,
  reversescale = FALSE,
  colorbar = list(title = "New Deaths")
) %>%
  layout(
    title = paste("Peak Daily COVID-19 Deaths (", peak_death_day, ")", sep=""),
    geo = list(scope = "usa")
  )

fig_peak_deaths

Additionally, to visualize the spatiotemporal progression of COVID-19 in the United States, monthly aggregated new cases and deaths were animated using a yellow-to-red epidemiological gradient. These animations reveal distinct pandemic waves, including initial coastal outbreaks, the nationwide winter 2020 surge, and later variant-driven resurgences, with clear differences in timing and severity across regions.

# Aggregate cumulative totals by month
monthly_cum <- covid19_ml %>%
  mutate(month = floor_date(date, "month")) %>%
  group_by(state_name, state_abbrev, month) %>%
  summarise(total_cases = max(cumulative_confirmed, na.rm = TRUE)) %>%
  ungroup() %>%
  add_fips

p_cum_cases <- plot_usmap(
  data = monthly_cum,
  values = "total_cases",
  regions = "states"
) +
  scale_fill_gradient(
    low = "#FFFBC8",
    high = "red",
    name = "Cumulative Cases",
    labels = comma
  ) +
  labs(
    title = "Cumulative COVID-19 Cases by State",
    subtitle = 'Month: {format(frame_time, "%b %Y")}'
  ) +
  theme(legend.position = "right") +
  transition_time(month) +
  ease_aes("linear")

gif_cum_cases <- animate(
  p_cum_cases,
  nframes = length(unique(monthly_cum$month)),
  fps = 10,
  width = 1200,
  height = 700,
  renderer = magick_renderer()
)

anim_save("covid_monthly_cumulative_cases.gif", gif_cum_cases)

monthly_cum_deaths <- covid19_ml %>%
  mutate(month = floor_date(date, "month")) %>%
  group_by(state_name, state_abbrev, month) %>%
  summarise(total_deaths = max(cumulative_deceased, na.rm = TRUE)) %>%
  ungroup() %>%
  mutate(
    fips = usmap::fips(state = state_name),
    fips = case_when(
      state_name == "American Samoa"           ~ "60",
      state_name == "Guam"                     ~ "66",
      state_name == "Northern Mariana Islands" ~ "69",
      state_name == "Virgin Islands"           ~ "78",
      TRUE ~ as.character(fips)
    )
  )

p_cum_deaths <- plot_usmap(
  data = monthly_cum_deaths,
  values = "total_deaths",
  regions = "states"
) +
  scale_fill_gradient(
    low = "#FFFBC8",
    high = "red",
    name = "Cumulative Deaths",
    labels = comma
  ) +
  labs(
    title = "Cumulative COVID-19 Deaths by State",
    subtitle = 'Month: {format(frame_time, "%b %Y")}'
  ) +
  theme(legend.position = "right") +
  transition_time(month) +
  ease_aes("linear")

gif_cum_deaths <- animate(
  p_cum_deaths,
  nframes = length(unique(monthly_cum_deaths$month)),
  fps = 10,
  width = 1200,
  height = 700,
  renderer = magick_renderer()
)

anim_save("covid_monthly_cumulative_deaths.gif", gif_cum_deaths)

4 Results

4.1 Data Preparation

To prepare for modeling, we split the data into training, validation, and testing sets based on the timeline of the pandemic. Since the major surge occurred in 2021 and early 2022, we use all data through September 30, 2021 for training. The months through December 31, 2021 serve as a validation period for tuning hyperparameters, while data from January to mid-June 2022 is reserved for final testing to evaluate generalization during the peak (after June 2022, many features become missing). Model performance is assessed using Root Mean Squared Error (RMSE) and Mean Absolute Error (MAE).

cols_exclude_common <- c(
  # identifiers
  "state_name", "state_abbrev", 
  "country_code", "country_name", "new_confirmed", "new_deceased",
  
  # target leakage
  "cumulative_confirmed", "cumulative_deceased"
)
cols_exclude_state <- c(
  "population", "area_sq_km", "life_expectancy",
  "latitude", "longitude", "elevation_m",
  "population_male", "population_female",
  grep("population_age_", names(covid19_ml), value = TRUE)
)
# Temporal cutoff dates
train_end <- as.Date("2021-09-30")
val_end <- as.Date("2021-12-31")
test_end  <- as.Date("2022-06-15")
build_state_ml <- function(state, target) {
  df <- covid19_ml %>%
    filter(state_name == state) %>%
    arrange(date) %>%
    
    # Remove identifiers and leakage
    select(
      -all_of(cols_exclude_common),
      -all_of(cols_exclude_state)
    )
  
  # Remove completely NA columns
  df <- df[, colSums(!is.na(df)) > 0, drop = FALSE]
  df <- df %>% arrange(date)
  # Remove early flat zeros in target (meaning no outbreak yet)
  first_pos <- which(df[[target]] > 0)[1]
  df <- df[first_pos:nrow(df), ]
  non_na_cols <- names(df)[colSums(!is.na(df)) > 0]
  df <- df[, non_na_cols, drop = FALSE]
  # Identify constant columns in training
  constant_cols <- names(df)[
    sapply(df, function(x) length(unique(x)) <= 1)
  ]
  df <- df %>% select(-all_of(constant_cols))
  df <- df %>%
    mutate(across(
      where(is.list),
      ~ suppressWarnings(as.numeric(.x))
    ))
  return(df)
}

## some states may not have certain features (so we need to remove them)
prepare_xy_state <- function(state, target) {
  
  df <- build_state_ml(state, target) %>%
    arrange(date)
  
  # Split sets
  train_df <- df %>% filter(date <= train_end)
  val_df <- df %>% filter(date >  train_end & date <= val_end)
  test_df  <- df %>% filter(date >  val_end & date <= test_end)

  # Extract predictors and target
  X_train <- train_df %>% select(-all_of(c('target_cases', 'target_deaths', 'date')))
  y_train <- train_df[[target]]
  
  X_val <- val_df %>% select(-all_of(c('target_cases', 'target_deaths', 'date')))
  y_val <- val_df[[target]]
  
  X_test  <- test_df %>% select(-all_of(c('target_cases', 'target_deaths', 'date')))
  y_test  <- test_df[[target]]
  
  return(list(
    X_train = as.matrix(X_train),
    y_train = y_train,
    X_val = as.matrix(X_val),
    y_val = y_val,
    X_test = as.matrix(X_test),
    y_test = y_test,
    features = colnames(X_train)
  ))
}

rmse <- function(pred, obs) sqrt(mean((pred - obs)^2))
mae <- function(pred, obs) mean(abs(pred - obs))  # mean of absolute errors; abs(mean(...)) would let signed errors cancel
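Note that MAE is the mean of the absolute errors, `mean(abs(pred - obs))`, not the absolute value of the mean error, where positive and negative residuals can cancel. A minimal self-contained check on toy values:

```r
# Error metrics, as used throughout this report.
rmse <- function(pred, obs) sqrt(mean((pred - obs)^2))
mae  <- function(pred, obs) mean(abs(pred - obs))

pred <- c(10, 20, 30)
obs  <- c(12, 18, 30)

rmse(pred, obs)  # sqrt((4 + 4 + 0) / 3), about 1.63
mae(pred, obs)   # (2 + 2 + 0) / 3, about 1.33

# By contrast, abs(mean(pred - obs)) would report 0 here,
# because the +2 and -2 errors cancel before the absolute value is taken.
```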

4.2 Pennsylvania Modeling Example

This dataset contains 56 U.S. states and territories. Therefore, we will train and evaluate separate models for each region using a loop. In this section, we present Pennsylvania as an example to demonstrate our modeling approach for predicting new confirmed cases and deaths.
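The per-region loop itself is not shown in this section; its shape can be sketched with a self-contained toy example, where `fit_one_region` is a hypothetical stand-in for the `prepare_xy_state` plus model-fitting pipeline demonstrated below:

```r
# Toy data: three regions, each with its own short series.
toy <- data.frame(
  region = rep(c("PA", "NY", "OH"), each = 5),
  y      = c(1:5, 2:6, 3:7)
)

# Hypothetical per-region "model": here simply the mean of y.
fit_one_region <- function(df) mean(df$y)

# One fit per region, collected into a named list keyed by region.
region_fits <- lapply(split(toy, toy$region), fit_one_region)
region_fits$PA  # 3
```

In the real pipeline, `fit_one_region` would call `prepare_xy_state(state, 'target_cases')` and run the models of Section 4.3, storing each state's RMSE and MAE in the list.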

state_name = 'Pennsylvania'
xy <- prepare_xy_state(state_name, 'target_cases')
X_train <- xy$X_train
y_train <- xy$y_train
X_val   <- xy$X_val
y_val   <- xy$y_val
X_test  <- xy$X_test
y_test  <- xy$y_test

4.2.1 Feature Selection - LASSO

Before training the models, we perform feature selection because the dataset contains more than 500 predictors, which may introduce noise and negatively impact model performance. Reducing the number of features helps improve both efficiency and accuracy. To identify the most relevant predictors, we apply LASSO, a widely used regularization technique that performs automatic feature selection by shrinking less informative coefficients to zero.

cat("\n--- LASSO ---\n")

--- LASSO ---
results <- list()
alphas <- seq(0, 1, by = 0.2)  # alpha = 1 is the pure LASSO; alpha < 1 gives elastic-net/ridge mixtures
best_rmse <- Inf
best <- NULL

for(a in alphas){
  fit_cv <- cv.glmnet(X_train, y_train, nfolds=5, alpha=a)
  pred_val <- predict(fit_cv, X_val, s = "lambda.min")
  score <- rmse(pred_val, y_val)
  
  if(score < best_rmse){
    best_rmse <- score
    best <- list(model = fit_cv, alpha = a)
    best_pred_val <- pred_val
  }
}

final_lasso <- glmnet(
  X_train, y_train,
  alpha = best$alpha,
  lambda = best$model$lambda.min
)
pred_test <- predict(final_lasso, X_test)

# Extract coefficients
co <- as.matrix(coef(final_lasso))[-1, 1]   # drop intercept

# Ranked list of important features (non-zero coefficients)
important_features <- names(co[co != 0][order(-abs(co[co != 0]))])

results$lasso <- list(
  test_rmse = rmse(pred_test, y_test),
  test_mae  = mae(pred_test, y_test),
  pred_val = as.numeric(best_pred_val),
  pred_test = as.numeric(pred_test),
  best_params = list(alpha = best$alpha, lambda = best$model$lambda.min),
  feature_importance = important_features
)
cat("Test RMSE:", results$lasso$test_rmse, "\n")
Test RMSE: 4928.946 
cat("Test MAE :", results$lasso$test_mae, "\n")
Test MAE : 1469.883 
X_train_sel <- X_train[, important_features, drop = FALSE]
X_val_sel <- X_val[, important_features, drop = FALSE]
X_test_sel  <- X_test[, important_features, drop = FALSE]

The most important features for predicting new confirmed COVID-19 cases in Pennsylvania, ranked by absolute coefficient, are:

important_features
 [1] "search_trends_dysgeusia"                      
 [2] "search_trends_neck_mass"                      
 [3] "search_trends_chancre"                        
 [4] "search_trends_ageusia"                        
 [5] "search_trends_burning_chest_pain"             
 [6] "search_trends_hyperemesis_gravidarum"         
 [7] "search_trends_eye_pain"                       
 [8] "search_trends_hyperventilation"               
 [9] "search_trends_delayed_onset_muscle_soreness"  
[10] "search_trends_onychorrhexis"                  
[11] "search_trends_vertigo"                        
[12] "search_trends_diabetic_ketoacidosis"          
[13] "search_trends_thrombocytopenia"               
[14] "search_trends_food_craving"                   
[15] "search_trends_round_ligament_pain"            
[16] "search_trends_hypoxemia"                      
[17] "search_trends_panic_attack"                   
[18] "search_trends_biliary_colic"                  
[19] "search_trends_leg_cramps"                     
[20] "search_trends_halitosis"                      
[21] "search_trends_dystonia"                       
[22] "search_trends_nosebleed"                      
[23] "search_trends_beaus_lines"                    
[24] "search_trends_coma"                           
[25] "search_trends_epiphora"                       
[26] "search_trends_renal_colic"                    
[27] "search_trends_tremor"                         
[28] "search_trends_facial_nerve_paralysis"         
[29] "search_trends_pus"                            
[30] "search_trends_low_back_pain"                  
[31] "search_trends_intracranial_pressure"          
[32] "search_trends_chest_pain"                     
[33] "search_trends_abdominal_obesity"              
[34] "cancel_public_events"                         
[35] "search_trends_gastroesophageal_reflux_disease"
[36] "search_trends_red_eye"                        
[37] "search_trends_hypertrophy"                    
[38] "search_trends_anal_fissure"                   
[39] "search_trends_hiccup"                         
[40] "search_trends_meningitis"                     
[41] "search_trends_stroke"                         
[42] "search_trends_asperger_syndrome"              
[43] "search_trends_compulsive_hoarding"            
[44] "search_trends_rhinitis"                       
[45] "new_hospitalized_patients"                    
[46] "mobility_workplaces"                          
[47] "mobility_grocery_and_pharmacy"                
[48] "current_intensive_care_patients"              
[49] "lag_cases_7"                                  
[50] "current_hospitalized_patients"                
[51] "minimum_temperature_celsius"                  
[52] "new_persons_fully_vaccinated_janssen"         

4.3 Machine Learning Modeling

In this section, we evaluate the performance of four predictive models: LASSO Regression, Linear Regression, Random Forest, and XGBoost. The results for LASSO Regression were presented earlier. For each model (except Linear Regression), we also perform hyperparameter tuning to optimize performance before evaluating them on the test dataset.

4.3.1 Linear Regression

cat("\n--- Linear Regression ---\n")

--- Linear Regression ---
lm_fit <- lm(y_train ~ ., data = as.data.frame(X_train_sel))
pred_val <- predict(lm_fit, as.data.frame(X_val_sel))
pred_test <- predict(lm_fit, as.data.frame(X_test_sel))

lm_imp <- broom::tidy(lm_fit) %>%
  rename(feature = term, coef = estimate) %>%
  arrange(desc(abs(coef)))
important_features <- lm_imp %>%
  filter(feature != "(Intercept)") %>%
  pull(feature)
results$linear <- list(
  test_rmse = rmse(pred_test, y_test),
  test_mae  = mae(pred_test, y_test),
  pred_val  = pred_val,
  pred_test = pred_test,
  best_params = list(model = "OLS-baseline"),
  feature_importance = important_features
)
cat("Test RMSE:", results$linear$test_rmse, "\n")
Test RMSE: 5029.937 
cat("Test MAE :", results$linear$test_mae, "\n")
Test MAE : 1625.584 

4.3.2 Random Forest

cat("\n--- Random Forest ---\n")

--- Random Forest ---
rf_grid <- expand.grid(
  ntree = c(200, 400, 600),
  mtry  = c(
    floor(sqrt(ncol(X_train_sel))),
    floor(ncol(X_train_sel)/3),
    floor(ncol(X_train_sel)/5)
  ),
  maxnodes = c(20, 40)  # note: c(NULL, 20, 40) silently drops the NULL, so an unlimited-depth option must be handled separately
)

best_rmse <- Inf
best_rf <- NULL

for(i in 1:nrow(rf_grid)){
  fit <- randomForest(
    X_train_sel, y_train,
    ntree = rf_grid$ntree[i],
    mtry = rf_grid$mtry[i],
    maxnodes = rf_grid$maxnodes[i]
  )
  pred_val <- predict(fit, X_val_sel)
  score <- rmse(pred_val, y_val)
  
  if(score < best_rmse){
    best_rmse <- score
    best_rf <- list(model=fit, hp=rf_grid[i,], pred_val=pred_val)
  }
}

pred_test <- predict(best_rf$model, X_test_sel)

imp <- data.frame(
  feature = rownames(importance(best_rf$model)),
  importance = importance(best_rf$model)[,1]
) %>% arrange(desc(importance))
important_features_rf <- imp %>%
  pull(feature)
results$rf <- list(
  test_rmse = rmse(pred_test, y_test),
  test_mae  = mae(pred_test, y_test),
  pred_val  = best_rf$pred_val,
  pred_test = pred_test,
  best_params = as.list(best_rf$hp),
  feature_importance = important_features_rf
)
cat("Test RMSE:", results$rf$test_rmse, "\n")
Test RMSE: 6579.478 
cat("Test MAE :", results$rf$test_mae, "\n")
Test MAE : 2905.084 

4.3.3 XGBoost

cat("\n--- XGBoost ---\n")

--- XGBoost ---
dtrain <- xgb.DMatrix(X_train_sel, label=y_train)
dval   <- xgb.DMatrix(X_val_sel,   label=y_val)
dtest  <- xgb.DMatrix(X_test_sel,  label=y_test)

xgb_grid <- expand.grid(
  eta = c(0.03, 0.05, 0.1),
  max_depth = c(4, 6, 10),
  min_child_weight = c(1, 3),
  subsample = c(0.7, 0.9),
  colsample_bytree = c(0.7, 0.9)
)

best_rmse <- Inf
best_xgb <- NULL

for(i in 1:nrow(xgb_grid)){
  params <- list(
    objective="reg:squarederror",
    eval_metric="rmse",
    eta = xgb_grid$eta[i],
    max_depth = xgb_grid$max_depth[i],
    min_child_weight = xgb_grid$min_child_weight[i],
    subsample = xgb_grid$subsample[i],
    colsample_bytree = xgb_grid$colsample_bytree[i]
  )
  
  fit <- xgb.train(params, dtrain, nrounds=350, verbose=0)
  pred_val <- predict(fit, dval)
  score <- rmse(pred_val, y_val)
  
  if(score < best_rmse){
    best_rmse <- score
    best_xgb <- list(model=fit, hp=xgb_grid[i,], pred_val=pred_val)
  }
}

pred_test <- predict(best_xgb$model, dtest)
imp_xgb <- xgb.importance(
  model = best_xgb$model,
  feature_names = colnames(X_train_sel)
)
important_features <- imp_xgb %>%
  arrange(desc(Gain)) %>%      # Gain is the main importance metric
  pull(Feature)
results$xgb <- list(
  test_rmse = rmse(pred_test, y_test),
  test_mae  = mae(pred_test, y_test),
  pred_val  = best_xgb$pred_val,
  pred_test = pred_test,
  best_params = as.list(best_xgb$hp),
  feature_importance = important_features
)
cat("Test RMSE:", results$xgb$test_rmse, "\n")
Test RMSE: 6277.436 
cat("Test MAE :", results$xgb$test_mae, "\n")
Test MAE : 2631.464 
df <- build_state_ml(state_name, 'target_cases') %>%
  arrange(date)  %>%
  filter(date <= test_end)
train_df <- df %>% filter(date <= train_end)
val_df <- df %>% filter(date >  train_end & date <= val_end)
test_df  <- df %>% filter(date >  val_end & date <= test_end)
# --- Prepare plotting dataframe ---
df_truth <- df %>%
  select(date, target_cases) %>%
  rename(value = target_cases) %>%
  left_join(df %>% select(date) %>%
              mutate(phase = case_when(
                date <= train_end ~ "Training",
                date > train_end & date <= val_end ~ "Validation",
                date > val_end & date <= test_end ~ "Test"
              )),
            by = "date") %>%
  mutate(type = "Observed")

get_pred_df <- function(model_name, results_list) {
  val_len  <- length(results_list[[model_name]]$pred_val)
  test_len <- length(results_list[[model_name]]$pred_test)

  bind_rows(
    data.frame(
      date = val_df$date[seq_len(val_len)],
      value = results_list[[model_name]]$pred_val,
      model = model_name,
      phase = "Validation",
      type = "Prediction"
    ),
    data.frame(
      date = test_df$date[seq_len(test_len)],
      value = results_list[[model_name]]$pred_test,
      model = model_name,
      phase = "Test",
      type = "Prediction"
    )
  )
}

model_names <- c("linear", "rf", "xgb","lasso")

df_preds <- bind_rows(lapply(model_names, get_pred_df, results_list = results))

# Filter data only to validation + testing period (no training)
df_truth_sub <- df_truth %>%
  filter(date > train_end & date <= test_end)

df_preds_sub <- df_preds %>%
  filter(date <= test_end)

y_max <- max(df_truth_sub$value, na.rm = TRUE)

p <- ggplot() +
  # Observed FIRST → beneath models
  geom_line(data = df_truth_sub,
            aes(x = date, y = value),
            color = "black", linewidth = 1.2, alpha = 0.7) +
  
  # Predictions SECOND → ON TOP
  geom_line(data = df_preds_sub,
            aes(x = date, y = value, color = model),
            linewidth = 1.4, alpha = 0.9) +

  geom_vline(xintercept = as.numeric(val_end),
             linetype = "dashed", color = "gray40", linewidth = 0.8) +

  annotate("text", x = val_end - 20, y = y_max,
           label = "Validation", hjust = 1, size = 5, fontface = "bold") +
  annotate("text", x = val_end + 20, y = y_max,
           label = "Testing", hjust = 0, size = 5, fontface = "bold") +

  scale_color_manual(
    values = c(
      "linear" = "#33a02c",
      "rf"     = "#e31a1c",
      "xgb"    = "#ff7f00",
      "lasso"  = "#1f78b4"
    ),
    name = "Predictive Model"
  ) +

  labs(
    title = paste("COVID-19 Confirmed Case Predictions —", state_name),
    x = "Date",
    y = "New Confirmed Cases"
  ) +
  
  theme_minimal(base_size = 16) +
  theme(
    plot.title = element_text(size = 16, face = "bold"),
    plot.subtitle = element_text(size = 12),
    plot.margin = ggplot2::margin(20, 20, 40, 20, unit = "pt"),
    legend.box.margin = ggplot2::margin(10, 0, 0, 0, unit = "pt"),
    legend.margin = ggplot2::margin(5, 5, 5, 5, unit = "pt")
  )
p
Warning in scale_x_date(): A <numeric> value was passed to a Date scale.
ℹ The value was converted to a <Date> object.

ggsave(
  filename = paste0("covid_predictions_", state_name, ".png"),
  plot = p,
  width = 10,
  height = 6,
  dpi = 300
)
Warning in scale_x_date(): A <numeric> value was passed to a Date scale.
ℹ The value was converted to a <Date> object.

Let’s identify the most important features (those contributing most to the COVID-19 case predictions) shared across all models, ranked by mean rank:

model_names <- names(results)  # c("lasso", "linear", "rf", "xgb")

rank_list <- purrr::map(model_names, function(m) {
  feats <- results[[m]]$feature_importance
  data.frame(
    feature = feats,
    rank = seq_along(feats),
    model = m
  )
})

# Combine into one long table
rank_df <- bind_rows(rank_list)

# Compute mean rank (lower = more important)
feature_rank_summary <- rank_df %>%
  group_by(feature) %>%
  summarise(
    mean_rank = mean(rank),
    sd_rank = sd(rank)  # helpful to see rank stability
  ) %>%
  arrange(mean_rank)

# Top 20 global features
top20_features <- feature_rank_summary %>%
  slice(1:20)

top20_features
# A tibble: 20 × 3
   feature                                     mean_rank sd_rank
   <chr>                                           <dbl>   <dbl>
 1 search_trends_ageusia                            6       3.37
 2 search_trends_dysgeusia                          9.25   12.8 
 3 search_trends_hypoxemia                         12.2     8.18
 4 search_trends_burning_chest_pain                14       8.60
 5 search_trends_leg_cramps                        15       6.38
 6 search_trends_hyperemesis_gravidarum            17.2    13.4 
 7 search_trends_onychorrhexis                     17.8    10.3 
 8 search_trends_vertigo                           18.5     8.89
 9 search_trends_delayed_onset_muscle_soreness     18.8    10.8 
10 search_trends_diabetic_ketoacidosis             19.8     9.00
11 search_trends_panic_attack                      20.2     9.71
12 search_trends_biliary_colic                     21      15.7 
13 search_trends_facial_nerve_paralysis            21      12.6 
14 search_trends_pus                               21.5     9.47
15 search_trends_thrombocytopenia                  21.5     8.19
16 search_trends_neck_mass                         22      23.6 
17 search_trends_low_back_pain                     23      13.3 
18 new_hospitalized_patients                       23.5    26.0 
19 search_trends_food_craving                      23.8    10.1 
20 search_trends_abdominal_obesity                 24.2    12.7 
top20_features$feature
 [1] "search_trends_ageusia"                      
 [2] "search_trends_dysgeusia"                    
 [3] "search_trends_hypoxemia"                    
 [4] "search_trends_burning_chest_pain"           
 [5] "search_trends_leg_cramps"                   
 [6] "search_trends_hyperemesis_gravidarum"       
 [7] "search_trends_onychorrhexis"                
 [8] "search_trends_vertigo"                      
 [9] "search_trends_delayed_onset_muscle_soreness"
[10] "search_trends_diabetic_ketoacidosis"        
[11] "search_trends_panic_attack"                 
[12] "search_trends_biliary_colic"                
[13] "search_trends_facial_nerve_paralysis"       
[14] "search_trends_pus"                          
[15] "search_trends_thrombocytopenia"             
[16] "search_trends_neck_mass"                    
[17] "search_trends_low_back_pain"                
[18] "new_hospitalized_patients"                  
[19] "search_trends_food_craving"                 
[20] "search_trends_abdominal_obesity"            

4.3.4 Evaluations & Discussion

From the model evaluation results, it is evident that LASSO Regression and Linear Regression consistently outperform Random Forest and XGBoost in predicting new confirmed COVID-19 cases during both the validation and testing periods, achieving the lowest RMSE and MAE. A likely reason is that we restricted the feature set for all models to the subset selected by LASSO. While this ensures consistency, it may also prevent the tree-based models from accessing other potentially informative features, ultimately limiting their predictive performance. In addition, tree-based models such as Random Forest and XGBoost may be more prone to overfitting in this setting, particularly given the limited number of real predictive signals and the presence of high-dimensional, correlated features. Linear models, by contrast, benefit from strong regularization (LASSO) or coefficient shrinkage, which reduces noise and prevents the model from relying on unstable or spurious predictors. The tree-based models also show more volatility in the testing phase, suggesting weaker robustness to distribution shifts in unseen data, especially around the major surges in early 2022 when case dynamics changed.

Regardless of the modeling approach, the testing errors remain relatively high (on the order of thousands of cases), highlighting the challenge of accurately predicting COVID-19 case trends with the current feature set. Although data imputation can help address missing values, it may not fully capture the true underlying patterns. Moreover, infection rates are influenced by numerous additional factors, such as human behavior, local policy interventions, immunity variability, and healthcare accessibility, that are weakly represented or entirely absent in this dataset. As a result, while these models provide valuable exploratory insights, the lack of comprehensive and reliable data makes it difficult to draw strong conclusions or generate highly accurate forecasts.

The most important features consistently identified across all four models were Google search trends related to common COVID-19 symptoms. Searches for loss of taste and smell (ageusia, dysgeusia), hypoxemia, chest pain, muscle soreness, and dizziness strongly align with core clinical presentations of the disease, while trends related to thrombocytopenia, gastrointestinal issues, panic attacks, and neurological symptoms reflect its broader systemic effects. These patterns may arise as individuals begin experiencing symptoms themselves or observe them in people nearby, prompting online searches for explanations or remedies even before seeking medical care. In this way, symptom-related search activity can reveal emerging transmission within a community earlier than official reporting channels. The prominence of these digital symptom indicators suggests that real-time search behavior may serve as an early predictor of emerging outbreaks, capturing shifts in disease burden before they are fully reflected in clinical case reporting.

Interestingly, we also found that some of the top-ranked search trends are not typically recognized as COVID-19 symptoms but may still provide valuable predictive insight. For example, searches for leg cramps, neck mass, biliary colic, abdominal obesity, and panic attacks may reflect hidden effects of infection such as clotting abnormalities, lymph node swelling, metabolic changes, or heightened anxiety and are increasingly observed in post-acute COVID-19 cases. Their strong contribution to forecasting suggests that public search behavior may capture early or indirect responses to COVID-19 transmission that are not yet fully understood, pointing toward potential avenues for future clinical and epidemiologic discovery.

4.3.5 Limitations and Future Directions

This study has a few key limitations. The dataset contains substantial missing and noisy values, particularly early in the pandemic and again around June to September 2022, which may impact prediction accuracy even after preprocessing. Results were demonstrated using only one state as an example, so broader generalization remains uncertain. Additionally, the regression-based framework assumes relatively stable linear relationships between predictors and case trajectories, whereas real-world epidemic dynamics are highly non-linear and influenced by rapid shifts in policy, variants, and human behavior.

Future work should extend the modeling approach across multiple states to better evaluate regional differences. More robust imputation and feature engineering strategies could improve data quality and predictive power. Incorporating additional contextual information, such as mobility, public health interventions, and variant prevalence, and exploring more advanced time-series forecasting methods may help capture the complex temporal behavior of the pandemic. Ensemble approaches may also offer a better balance between performance and interpretability.
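As one concrete direction, a minimal equal-weight ensemble simply averages the predictions of the two strongest models. The sketch below uses illustrative numbers (not results from this project) to show how partially offsetting errors can lower the combined RMSE:

```r
rmse <- function(pred, obs) sqrt(mean((pred - obs)^2))

# Illustrative predictions from two hypothetical models, plus observed values.
obs         <- c(1100, 1550, 1900)
pred_lasso  <- c(1200, 1500, 1800)
pred_linear <- c(1050, 1600, 2000)

# Equal-weight ensemble: average the two prediction vectors elementwise.
ensemble <- (pred_lasso + pred_linear) / 2

rmse(pred_lasso, obs)   # ~86.6
rmse(pred_linear, obs)  # ~70.7
rmse(ensemble, obs)     # ~14.4, lower than either model alone
```

Here the two models err in opposite directions on most dates, so their average lands closer to the truth; in practice, validation-weighted averaging would replace the equal weights.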

5 Conclusion

Overall, this work demonstrates that relatively simple machine learning models, particularly LASSO and Linear Regression, can provide reliable short-term predictions of new COVID-19 cases at the state level when supported by careful preprocessing and feature selection. Although more complex models like XGBoost and Random Forest hold promise, they require more data and additional techniques to avoid overfitting and cope with shifting pandemic conditions. This study provides a strong foundation for more scalable and robust epidemic prediction frameworks, with promising potential for refinement through improved data quality, enhanced temporal modeling, and deployment across multiple geographic regions.